#pragma once
#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>
#include "zMatrixBase.hpp"

// Sparse matrix. It still lacks most operations.

namespace zzz {
// It should not be derived from zMatrixBaseW,
// since we don't want it to be writable like a general zMatrixBaseW.
// Although it can be written that way, any such operation will destroy the sparsity.
// Sparse matrix that keeps only explicitly stored cells, each as a
// (row, col, value) node in one flat vector.
//
// The vector is either unsorted, row-major sorted or column-major sorted; the
// current state is tracked in sorted_. Cell lookup is a binary search when the
// nodes are sorted and a linear scan otherwise.
template<typename T, class Major=zColMajor>
class zSparseMatrix : public zMatrixBaseW<T,Major> 
{
public:
  // One stored cell: its coordinates and its value.
  struct zSparseMatrixNode {
    zSparseMatrixNode(zuint row, zuint col, const T& v):row_(row), col_(col), v_(v) {}
    zuint row_, col_;
    T v_;
  };

  // Empty 0x0 matrix.
  zSparseMatrix():zMatrixBaseW<T,Major>(0, 0), sorted_(SORT_NONE){}

  // Matrix of the given size with no stored cells.
  zSparseMatrix(zuint row, zuint col):zMatrixBaseW<T,Major>(row, col), sorted_(SORT_NONE){}

  // Copies the size, the sort state and all stored nodes of other.
  zSparseMatrix(const zSparseMatrix<T,Major> &other)
  :zMatrixBaseW<T,Major>(other), sorted_(other.sorted_), v_(other.v_){}

  // Builds a sparse copy of any readable matrix, keeping only cells != 0.
  // Cells are visited column by column, so the nodes end up column-major sorted.
  explicit zSparseMatrix(const zMatrixBaseR<T,Major> &other)
  :zMatrixBaseW<T,Major>(other.rows_, other.cols_), sorted_(SORT_COL) {
    // rows_/cols_ live in a dependent base class, hence the explicit this->.
    for (zuint c=0; c<this->cols_; c++) for (zuint r=0; r<this->rows_; r++) {
      const T v=other(r,c);
      if (v!=0) v_.push_back(zSparseMatrixNode(r, c, v));
    }
  }

  // Changes the matrix size. All stored cells are discarded first.
  void SetSize(zuint row, zuint col) {
    ClearData();
    this->rows_=row;
    this->cols_=col;
  }

  // Clear data only, do not change size
  void ClearData() {
    v_.clear();
  }

  // Get number of stored data cells
  zuint DataSize() const {
    return static_cast<zuint>(v_.size());
  }

  // Check if a cell is stored at (row, col)
  bool CheckExist(zuint row, zuint col) const {
    return GetConstNode(row, col)!=NULL;
  }

  // Removes the cell stored at (row, col); returns false if there is none.
  // This is slow, O(N) time. The relative order of the remaining nodes is
  // unchanged, so the sort state stays valid.
  bool Remove(zuint row, zuint col) {
    zSparseMatrixNode* node = GetNode(row, col);
    if (!node) return false;
    v_.erase(v_.begin() + (node - v_.data()));
    return true;
  }

  // Read access: value of the stored cell, or zero if the cell is not stored.
  const T operator()(zuint row, zuint col) const {
    return Get(row, col);
  }
  
  // Write access: creates the cell (with value zero) if it is not stored yet.
  T& operator()(zuint row, zuint col) {
    return MustGet(row, col);
  }

  // Replaces the content with the non-zero cells of other, which must have the
  // same size. The new nodes are collected in a temporary vector and swapped in
  // at the end: previously stored cells that are zero in other do not survive,
  // and other may safely read from this matrix while it is being evaluated.
  // Cells are visited row by row, so the result is row-major sorted.
  const zSparseMatrix<T,Major>& operator=(const zMatrixBaseR<T,Major> &other) {
    ZCHECK_EQ(this->rows_, other.rows_);
    ZCHECK_EQ(this->cols_, other.cols_);
    std::vector<zSparseMatrixNode> fresh;
    for (zuint r=0; r<this->rows_; r++) for (zuint c=0; c<this->cols_; c++) {
      const T v=other(r,c);
      if (v!=0) fresh.push_back(zSparseMatrixNode(r, c, v));
    }
    v_.swap(fresh);
    sorted_ = SORT_ROW;
    return *this;
  }

  // Value of the stored cell at (row, col), or ZERO_VALUE if it is not stored.
  const T Get(zuint row, zuint col) const {
    const zSparseMatrixNode* node = GetConstNode(row, col);
    if (node)
      return node->v_;
    else
      return ZERO_VALUE;
  }

  // Returns a writable reference to the value at (row, col). If the cell is not
  // stored yet, a node with value zero is appended first. The reference points
  // into the node vector, so it is only valid until the vector is modified.
  T& MustGet(zuint row, zuint col) {
    if (zSparseMatrixNode* node = GetNode(row, col))
      return node->v_;

    const zSparseMatrixNode newnode(row, col, T(0));
    // Appending keeps the current ordering only if the new node sorts strictly
    // after the current last node; otherwise fall back to the unsorted state.
    if (!v_.empty()) {
      if (sorted_==SORT_ROW && !RowMajorLess()(v_.back(), newnode)) sorted_=SORT_NONE;
      else if (sorted_==SORT_COL && !ColMajorLess()(v_.back(), newnode)) sorted_=SORT_NONE;
    }
    v_.push_back(newnode);
    return v_.back().v_;
  }

  // Stores number at (row, col). A zero number is ignored: nothing is created
  // and an already stored cell keeps its value.
  void AddData(int row, int col, const T& number) {
    if (number!=0)
      MustGet(static_cast<zuint>(row), static_cast<zuint>(col)) = number;
  }

  // i-th stored node, in current storage order (no bounds check).
  zSparseMatrixNode& GetNode(zuint i) {
    return v_[i];
  }

  const zSparseMatrixNode& GetNode(zuint i) const {
    return v_[i];
  }

  // Sorts the nodes by (col, row); does nothing if already in that order.
  void ColMajorSort() {
    if (sorted_ == SORT_COL) return;
    std::sort(v_.begin(), v_.end(), ColMajorLess());
    sorted_ = SORT_COL;
  }

  // Sorts the nodes by (row, col); does nothing if already in that order.
  void RowMajorSort() {
    if (sorted_ == SORT_ROW) return;
    std::sort(v_.begin(), v_.end(), RowMajorLess());
    sorted_ = SORT_ROW;
  }

private:
  // Strict ordering by row first, then by column.
  struct RowMajorLess {
    bool operator()(const zSparseMatrixNode &a, const zSparseMatrixNode &b) const {
      return a.row_ < b.row_ || (a.row_==b.row_ && a.col_<b.col_);
    }
  };
  // Strict ordering by column first, then by row.
  struct ColMajorLess {
    bool operator()(const zSparseMatrixNode &a, const zSparseMatrixNode &b) const {
      return a.col_ < b.col_ || (a.col_==b.col_ && a.row_<b.row_);
    }
  };
  // Predicate matching the node stored at one specific (row, col).
  struct RowColEqual {
    RowColEqual(zuint row, zuint col):row_(row), col_(col){}
    bool operator()(const zSparseMatrixNode &n) const {
      return n.row_ == row_ && n.col_ == col_;
    }
    zuint row_, col_;
  };

  // Mutable variant of GetConstNode(). The const_cast is safe because this
  // overload is only callable on a non-const matrix.
  zSparseMatrixNode* GetNode(zuint row, zuint col) {
    return const_cast<zSparseMatrixNode*>(GetConstNode(row, col));
  }

  // Pointer to the node stored at (row, col), or NULL if there is none.
  // Uses a binary search when the nodes are sorted, a linear scan otherwise.
  const zSparseMatrixNode* GetConstNode(zuint row, zuint col) const {
    typedef typename std::vector<zSparseMatrixNode>::const_iterator ConstIter;
    const RowColEqual same_cell(row, col);
    ConstIter vi = v_.end();
    switch (sorted_) {
    case SORT_ROW:
      vi = std::lower_bound(v_.begin(), v_.end(), zSparseMatrixNode(row, col, T(0)), RowMajorLess());
      break;
    case SORT_COL:
      vi = std::lower_bound(v_.begin(), v_.end(), zSparseMatrixNode(row, col, T(0)), ColMajorLess());
      break;
    case SORT_NONE:
      vi = std::find_if(v_.begin(), v_.end(), same_cell);
      break;
    default:
      ZCHECK(false)<<"Unknown sort type, this should not happen!";
      return NULL;
    }
    // lower_bound only yields the first node not ordered before the key, so
    // the coordinates still have to be compared.
    if (vi==v_.end() || !same_cell(*vi)) return NULL;
    return &*vi;
  }

  static T ZERO_VALUE;
  enum SortType {SORT_NONE, SORT_ROW, SORT_COL} sorted_;
  std::vector<zSparseMatrixNode> v_;
};

// Out-of-class definition of the static zero value that Get() returns for
// cells that are not stored.
template <typename T, typename Major>
T zSparseMatrix<T, Major>::ZERO_VALUE = T(0);

// Computes transpose(A) * A for a sparse matrix.
//
// A is taken by non-const reference because its nodes are sorted column-major
// in place; its values are not changed. The result is a square matrix of size
// A.Size(1) x A.Size(1) (presumably the column count of A -- Size() is
// declared outside this file). Entry (r, c) is the dot product of columns r
// and c of A; only non-zero results are stored.
template<typename T, class Major>
zSparseMatrix<T, Major> ATA(zSparseMatrix<T, Major> &A)
{
  // The nested node type is a dependent name and needs 'typename'.
  typedef typename zSparseMatrix<T, Major>::zSparseMatrixNode Node;

  zSparseMatrix<T, Major> ret(A.Size(1), A.Size(1));
  // Marks the still-empty result as row-major sorted. Entries are added below
  // in increasing (r, c) order, so the result stays sorted while it is filled.
  ret.RowMajorSort();
  A.ColMajorSort();

  // One (index of first node, column) pair for every column of A that has at
  // least one stored node, in increasing column order.
  const zuint count = A.DataSize();
  std::vector<std::pair<zuint, zuint> > col_start;
  for (zuint i=0; i<count; i++) {
    const zuint this_col = A.GetNode(i).col_;
    if (col_start.empty() || this_col != col_start.back().second)
      col_start.push_back(std::make_pair(i, this_col));
  }
  // Sentinel: its index marks the end of the last column; its column is unused.
  col_start.push_back(std::make_pair(count, zuint(0)));

  const zuint ncols = static_cast<zuint>(col_start.size()) - 1;
  for (zuint line1=0; line1<ncols; line1++) {
    for (zuint line2=0; line2<ncols; line2++) {
      zuint pos1 = col_start[line1].first, pos2 = col_start[line2].first;
      const zuint end1 = col_start[line1+1].first, end2 = col_start[line2+1].first;
      const zuint r = col_start[line1].second, c = col_start[line2].second;
      // Merge the two row-sorted columns, multiplying values on matching rows.
      T v=0;
      while (pos1<end1 && pos2<end2) {
        const Node &node1 = A.GetNode(pos1);
        const Node &node2 = A.GetNode(pos2);
        if (node1.row_ == node2.row_) {
          v += node1.v_ * node2.v_;
          pos1++; pos2++;
        } else if (node1.row_ > node2.row_)
          pos2++;
        else
          pos1++;
      }
      if (v!=0)
        ret.AddData(static_cast<int>(r), static_cast<int>(c), v);
    }
  }

  return ret;
}
};  // namespace zzz
